{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cot import Collection\n",
    "from cot.stats import evaluation_as_table"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### generate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package punkt to /home/kon/nltk_data...\n",
      "[nltk_data]   Package punkt is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "coll = Collection.from_json(\"ts_67_empty\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings passed to coll.generate(config=...) below:\n",
    "# prompt selection first, then API / language-model parameters.\n",
    "config = {\n",
    "    # --- prompt selection ---\n",
    "    \"instruction_keys\": None,\n",
    "    # presumably one generation run per listed trigger; None looks like \"no trigger\" - TODO confirm\n",
    "    \"cot_trigger_keys\": [None, \"kojima-01\", \"zhou-01\"],\n",
    "    \"answer_extraction_keys\": \"auto-kojima\",\n",
    "    \"author\": \"thoughtsource\",\n",
    "    # --- API / model parameters ---\n",
    "    \"api_service\": \"openai_chat\",\n",
    "    \"api_time_interval\": 1,\n",
    "    \"engine\": \"gpt-3.5-turbo\",\n",
    "    \"temperature\": 0,\n",
    "    \"max_tokens\": 512,\n",
    "    # --- logging ---\n",
    "    \"verbose\": False,\n",
    "    \"warn\": False,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating commonsense_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8c120027fe5f4620a321c417d8560b30",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/67 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating med_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c868e22eb6744842a849ab1b35eea486",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/67 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating medmc_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ec9d5e08d11f4fc0bee73e929552df58",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/67 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating open_book_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "474d996a5254422d81e585ee6a85acf3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/67 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating strategy_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f2b09dc8e6ee456cb323d8f9d2c7549d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/67 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 418889520aa5072cdee3f824e99a0e9d in your message.).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating worldtree...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e70550a9e13c41cf80c202a61b394cb8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/67 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "coll.generate(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "546ece73eb9c49f5a045e9b9cd439072",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/67 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6461e0730ec74023910576aaa5b4a92a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/67 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "56e99e7cb3bb424c853a294a284abf1a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/67 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4a13da22ad62402a9f2798c6c079a7b3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/67 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e022b753a8634df2ab0dd386acf02f83",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/67 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6ed4118a876942eba83833d6f3e6e5be",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/67 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'commonsense_qa': {'validation': {'accuracy': {'gpt-3.5-turbo': {'None_None_kojima-A-E': 0.761194,\n",
       "     'None_kojima-01_kojima-A-E': 0.716418,\n",
       "     'None_zhou-01_kojima-A-E': 0.776119}}}},\n",
       " 'med_qa': {'test': {'accuracy': {'gpt-3.5-turbo': {'None_None_kojima-A-E': 0.522388,\n",
       "     'None_kojima-01_kojima-A-E': 0.567164,\n",
       "     'None_zhou-01_kojima-A-E': 0.567164}}}},\n",
       " 'medmc_qa': {'validation': {'accuracy': {'gpt-3.5-turbo': {'None_None_kojima-A-D': 0.537313,\n",
       "     'None_kojima-01_kojima-A-D': 0.522388,\n",
       "     'None_zhou-01_kojima-A-D': 0.507463}}}},\n",
       " 'open_book_qa': {'test': {'accuracy': {'gpt-3.5-turbo': {'None_None_kojima-A-D': 0.791045,\n",
       "     'None_kojima-01_kojima-A-D': 0.80597,\n",
       "     'None_zhou-01_kojima-A-D': 0.701493}}}},\n",
       " 'strategy_qa': {'train': {'accuracy': {'gpt-3.5-turbo': {'None_None_kojima-yes-no': 0.61194,\n",
       "     'None_kojima-01_kojima-yes-no': 0.626866,\n",
       "     'None_zhou-01_kojima-yes-no': 0.641791}}}},\n",
       " 'worldtree': {'test': {'accuracy': {'gpt-3.5-turbo': {'None_None_kojima-A-D': 0.955224,\n",
       "     'None_kojima-01_kojima-A-D': 0.895522,\n",
       "     'None_zhou-01_kojima-A-D': 0.895522}}}}}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coll.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def join_strings(list_of_strings, sep=\"-\"):\n",
    "    \"\"\"Join the string form of each item into one delimited string.\n",
    "\n",
    "    Args:\n",
    "        list_of_strings: An iterable of items, a single string, or None.\n",
    "            None yields \"None\"; a bare string is returned unchanged\n",
    "            (it is treated as one item, not split into characters).\n",
    "        sep: Separator placed between items. Defaults to \"-\" (a hyphen).\n",
    "\n",
    "    Returns:\n",
    "        The joined string; \"\" for an empty iterable.\n",
    "    \"\"\"\n",
    "    if list_of_strings is None:\n",
    "        return \"None\"\n",
    "    if isinstance(list_of_strings, str):\n",
    "        return list_of_strings\n",
    "    # str() each item so that None entries become \"None\" instead of raising.\n",
    "    return sep.join(str(item) for item in list_of_strings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e7088a6e41a84068ba9c33c61b8ae8b9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e33e1f5a93fc4360aaba7758ba61464f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "59bdf2189cc745b598bcdd85aeedd50b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b04181827a5d429a9d0511d72008f4e3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f68801209f624da8be02f95391bfa268",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f42b777a31714119ba481babcebe6eb4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Output filename: underscore-joined run settings plus the .json extension.\n",
    "coll.dump(\n",
    "    \"_\".join(\n",
    "        [\n",
    "            \"thoughtsource_67\",\n",
    "            config['api_service'],\n",
    "            config['engine'].replace(\"/\", \"_\"),\n",
    "            join_strings(config[\"instruction_keys\"]),\n",
    "            join_strings(config[\"cot_trigger_keys\"]),\n",
    "        ]\n",
    "    )\n",
    "    + \".json\"\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### merge into thoughtsource_67.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "ts_67 = Collection.from_json(\"thoughtsource_67.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a1944587a3b9432690a18ab324d50261",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "556b598370594908967396f3af6adfd6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bcaf696d35ca419bb16d2b0a54d39e95",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "062a09c8ece14d19bb3eec3dc2cc5b06",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dab50687bf6b4d11a968690a07b664ce",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a8484aeef6a24958adb461d791915a2b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e33316048f6540ff9a058803a5c0bea2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bfeb664d21c545979e83f2ea03a4f86d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cbbbcbb9944b488c97fb98393a7b3b86",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3e75e870578d4706ae28585ccf558a3d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "895da66b75034d948c9928764037bbf9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "922cceac55094cafbf9a08a2d6b87d29",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "ts_67 = ts_67.merge(coll)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "67c4bbc5576e46d8b003db16565de96f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6e4ac927473341d9928187e45b675076",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e0fcbf8b68fc47cc9ce8778640f83038",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3a5f575937104334979c6f59dc82bc1a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7b60f1290c7140058a176c6d1c97d2a4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "88c1b9ea69264f76ab6c2d7fe4642fc2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "ts_67.dump(\"thoughtsource_67.json\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll = Collection.from_json(\"thoughtsource_67.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# coll.select_generated_cots(author='thoughtsource', model=[\"gpt-3.5-turbo\", \"flan-T5-xxl\"], cot_trigger=None, instruction=None)\n",
    "coll.select_generated_cots(author='thoughtsource')\n",
    "coll.select_generated_cots(model='gpt-3.5-turbo')\n",
    "# coll.select_generated_cots(cot_trigger=[None,'zhou-01'])\n",
    "# coll.select_generated_cots(instruction=[None])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Named `evaluation` rather than `eval` so the builtin eval() is not shadowed\n",
    "# in the kernel namespace.\n",
    "evaluation = coll.evaluate()\n",
    "table = evaluation_as_table(evaluation)\n",
    "table"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "a19cb823e24c98ee05a9cfa4a3a579b5d56d4c1a735f2a12456750b95a1e155e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
